import bm25s
import Stemmer
from tqdm import tqdm
import json
import numpy as np
from rank_bm25 import BM25Okapi

def calculate_bm25_scores(train_embs, dev_embs):
    """Score every query against every corpus document with ``rank_bm25``.

    Documents and queries are tokenized by splitting on single spaces and
    scored with ``BM25Okapi``.

    Args:
        train_embs: Sequence of corpus documents (strings).
        dev_embs: A single query string, or a sequence of query strings.

    Returns:
        np.ndarray of BM25 scores in the original corpus order. The shape is
        ``(len(train_embs),)`` when exactly one query is given, and
        ``(n_queries, len(train_embs))`` otherwise. That includes
        ``(0, len(train_embs))`` when no queries are given.
    """
    # Only a bare string is a single query; any other iterable (list, tuple,
    # generator, ...) is treated as a batch of queries.
    if isinstance(dev_embs, str):
        dev_embs = [dev_embs]
    queries = list(dev_embs)
    corpus = list(train_embs)
    if not queries:
        # np.vstack([]) would raise; return an empty score matrix instead.
        return np.empty((0, len(corpus)))

    tokenized_corpus = [doc.split(" ") for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    scores = [bm25.get_scores(query.split(" ")) for query in tqdm(queries)]
    if len(scores) == 1:
        return np.array(scores[0])
    return np.vstack(scores)
def calculate_bm25_scores_bm25s(train_embs, dev_embs):
    """Score every query against every corpus document with ``bm25s``.

    Counterpart of ``calculate_bm25_scores`` using the ``bm25s`` backend.
    Corpus and queries are tokenized identically: English stop words are
    removed and tokens are stemmed with the Snowball "english" stemmer.

    Args:
        train_embs: Sequence of corpus documents (strings).
        dev_embs: A single query string, or a sequence of query strings.

    Returns:
        np.ndarray of BM25 scores in the original corpus order. The shape is
        ``(len(train_embs),)`` when exactly one query is given, and
        ``(n_queries, len(train_embs))`` otherwise. That includes
        ``(0, len(train_embs))`` when no queries are given.
    """
    # A bare string is one query; iterating it would yield characters.
    if isinstance(dev_embs, str):
        dev_embs = [dev_embs]
    queries = list(dev_embs)
    corpus = list(train_embs)
    n_docs = len(corpus)
    if not queries:
        return np.empty((0, n_docs))

    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    # Tokenize all queries in one batch, with the same settings as the corpus.
    query_tokens = bm25s.tokenize(queries, stopwords="en", stemmer=stemmer)
    # Without ``corpus=...``, retrieve() returns integer document indices.
    # Mapping back through the document text would merge duplicate documents.
    doc_ids, doc_scores = retriever.retrieve(query_tokens, k=n_docs)

    # retrieve() orders each row by descending score; scatter the scores back
    # into the original corpus order.
    scores = np.zeros((len(queries), n_docs), dtype=float)
    np.put_along_axis(scores, np.asarray(doc_ids), np.asarray(doc_scores), axis=1)

    if len(queries) == 1:
        return scores[0]
    return scores
if __name__ == "__main__":
    # Load the GSM8K train/test splits. Each file is expected to hold a JSON
    # list of objects that carry a "question" field.
    with open("data/gsm8k/gsm8k_train.json", 'r', encoding='utf-8') as file:
        ds_train = json.load(file)
    with open("data/gsm8k/gsm8k_test.json", 'r', encoding='utf-8') as file:
        ds_test = json.load(file)

    # Fail with a clear message here, instead of a NameError on train_embs /
    # dev_embs further down, when the files are not laid out as expected.
    layout_ok = (
        isinstance(ds_train, list)
        and all(isinstance(item, dict) for item in ds_train)
        and isinstance(ds_test, list)
        and all(isinstance(item, dict) for item in ds_test)
    )
    if not layout_ok:
        raise ValueError("expected both GSM8K files to contain a JSON list of objects")

    train_embs = [data["question"] for data in ds_train]
    # Only score the first 10 test questions.
    dev_embs = [data["question"] for data in ds_test][:10]
    for dev_emb in dev_embs:
        scores = calculate_bm25_scores_bm25s(train_embs=train_embs, dev_embs=[dev_emb])
        print(scores)